{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "8230365523c9330a",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "# Usage\n",
    "1. Install python dependencies\n",
    "```shell\n",
    "!pip install pypdf langchain unstructured transformers_stream_generator\n",
    "!pip install modelscope  nltk pydantic  tiktoken  llama-index\n",
    "```\n",
    "\n",
    "2. Download data files we need in this example\n",
    "```shell\n",
    "!wget https://modelscope.oss-cn-beijing.aliyuncs.com/resource/rag/averaged_perceptron_tagger.zip\n",
    "!wget https://modelscope.oss-cn-beijing.aliyuncs.com/resource/rag/punkt.zip\n",
    "!wget https://modelscope.oss-cn-beijing.aliyuncs.com/resource/rag/xianjiaoda.md\n",
    "\n",
    "!mkdir -p /root/nltk_data/tokenizers\n",
    "!mkdir -p /root/nltk_data/taggers\n",
    "!cp /mnt/workspace/punkt.zip /root/nltk_data/tokenizers\n",
    "!cp /mnt/workspace/averaged_perceptron_tagger.zip /root/nltk_data/taggers\n",
    "!cd /root/nltk_data/tokenizers; unzip punkt.zip;\n",
    "!cd /root/nltk_data/taggers; unzip averaged_perceptron_tagger.zip;\n",
    "\n",
    "!mkdir -p /mnt/workspace/custom_data\n",
    "!mv /mnt/workspace/xianjiaoda.md /mnt/workspace/custom_data\n",
    "\n",
    "!cd /mnt/workspace\n",
    "``` \n",
    "\n",
    "3. Enjoy your QA AI"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2a407764-9392-48ae-9bed-8c73c9f76fbc",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-01-16T08:58:56.323000Z",
     "iopub.status.busy": "2024-01-16T08:58:56.322690Z",
     "iopub.status.idle": "2024-01-16T08:59:57.862755Z",
     "shell.execute_reply": "2024-01-16T08:59:57.862041Z",
     "shell.execute_reply.started": "2024-01-16T08:58:56.322980Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "!pip install pypdf langchain unstructured transformers_stream_generator\n",
    "!pip install modelscope  nltk pydantic  tiktoken  llama-index"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "696c6b78-53e8-4135-8376-ce8902b7d79a",
   "metadata": {
    "ExecutionIndicator": {
     "show": true
    },
    "execution": {
     "iopub.execute_input": "2024-01-16T09:04:59.193375Z",
     "iopub.status.busy": "2024-01-16T09:04:59.193082Z",
     "iopub.status.idle": "2024-01-16T09:05:00.971449Z",
     "shell.execute_reply": "2024-01-16T09:05:00.970857Z",
     "shell.execute_reply.started": "2024-01-16T09:04:59.193357Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "!wget https://modelscope.oss-cn-beijing.aliyuncs.com/resource/rag/averaged_perceptron_tagger.zip\n",
    "!wget https://modelscope.oss-cn-beijing.aliyuncs.com/resource/rag/punkt.zip\n",
    "!wget https://modelscope.oss-cn-beijing.aliyuncs.com/resource/rag/xianjiaoda.md\n",
    "\n",
    "!mkdir -p /root/nltk_data/tokenizers\n",
    "!mkdir -p /root/nltk_data/taggers\n",
    "!cp /mnt/workspace/punkt.zip /root/nltk_data/tokenizers\n",
    "!cp /mnt/workspace/averaged_perceptron_tagger.zip /root/nltk_data/taggers\n",
    "!cd /root/nltk_data/tokenizers; unzip punkt.zip;\n",
    "!cd /root/nltk_data/taggers; unzip averaged_perceptron_tagger.zip;\n",
    "\n",
    "!mkdir -p /mnt/workspace/custom_data\n",
    "!mv /mnt/workspace/xianjiaoda.md /mnt/workspace/custom_data\n",
    "\n",
    "!cd /mnt/workspace"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "1cb8feca-c71f-4ad6-8eff-caae95411aa0",
   "metadata": {
    "ExecutionIndicator": {
     "show": true
    },
    "execution": {
     "iopub.execute_input": "2024-01-16T09:06:03.024995Z",
     "iopub.status.busy": "2024-01-16T09:06:03.024622Z",
     "iopub.status.idle": "2024-01-16T09:09:15.894774Z",
     "shell.execute_reply": "2024-01-16T09:09:15.894230Z",
     "shell.execute_reply.started": "2024-01-16T09:06:03.024974Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Downloading Model to directory: /mnt/workspace/.cache/modelscope/Qwen/Qwen-1_8B-Chat\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Downloading Model to directory: /mnt/workspace/.cache/modelscope/Qwen/Qwen-1_8B-Chat\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Try importing flash-attention for faster inference...\n",
      "Warning: import flash_attn rotary fail, please install FlashAttention rotary to get higher efficiency https://github.com/Dao-AILab/flash-attention/tree/main/csrc/rotary\n",
      "Warning: import flash_attn rms_norm fail, please install FlashAttention layer_norm to get higher efficiency https://github.com/Dao-AILab/flash-attention/tree/main/csrc/layer_norm\n",
      "Warning: import flash_attn fail, please install FlashAttention to get higher efficiency https://github.com/Dao-AILab/flash-attention\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "114cb3d66e9e4f6694ba66c91fc4b8a9",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2025-01-13 15:44:50,172 - modelscope.ast - INFO - Loading ast index from /mnt/workspace/.cache/modelscope/ast_indexer\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "STEP1: qianwen LLM created\n",
      "STEP2: reading docs ...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2025-01-13 15:44:50,565 - modelscope.ast - INFO - Updating the files for the changes of local files, first time updating will take longer time! Please wait till updating done!\n",
      "2025-01-13 15:44:50,588 - modelscope.ast - INFO - AST-Scanning the path \"/mnt/data/data/user/maoyunlin.myl/modelscope/package/modelscope\" with the following sub folders ['models', 'metrics', 'pipelines', 'preprocessors', 'trainers', 'msdatasets', 'exporters']\n",
      "2025-01-13 15:44:58,393 - modelscope.ast - INFO - Scanning done! A number of 987 components indexed or updated! Time consumed 7.80418848991394s\n",
      "2025-01-13 15:44:58,770 - modelscope.ast - INFO - Loading done! Current index file version is 2.0.0, with md5 78f3257e373ec3a7089c4256261fb13f and a total number of 987 components indexed\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Downloading Model to directory: /mnt/workspace/.cache/modelscope/damo/nlp_gte_sentence-embedding_chinese-small\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2025-01-13 15:44:59,577 - modelscope - INFO - Model revision not specified, using default: [master] version.\n",
      "2025-01-13 15:44:59,812 - modelscope - INFO - Got 9 files, start to download ...\n",
      "Processing 9 items:   0%|          | 0/9 [00:00<?, ?it/s]"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "506bd9e27c154fe4aca6fe3bc364cb52",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading [config.json]:   0%|          | 0.00/772 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "cc7af37066694704bb07c85f5d9bcd6c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading [tokenizer.json]:   0%|          | 0.00/291k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "042eed07f85b420bb8d908d3a9b001a4",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading [pytorch_model.bin]:   0%|          | 0.00/57.7M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e20c9b7d9a0246f0bedaa75606e7094e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading [configuration.json]:   0%|          | 0.00/2.02k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a3a7ec13561049d997d86bbe3fb9143e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading [special_tokens_map.json]:   0%|          | 0.00/125 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "496e78a215c54e21b334157c700c1522",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading [README.md]:   0%|          | 0.00/14.0k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5e9646c5f077462e9db0900939e9359c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading [resources/dual-encoder.png]:   0%|          | 0.00/60.7k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "91a04e5faaef45fba21cd3644866fcc8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading [tokenizer_config.json]:   0%|          | 0.00/425 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing 9 items:  11%|█         | 1/9 [00:00<00:02,  3.73it/s]"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1668295ed6c141d88e6b0dddc66108d7",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading [vocab.txt]:   0%|          | 0.00/68.4k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing 9 items: 100%|██████████| 9/9 [00:03<00:00,  2.71it/s]\n",
      "2025-01-13 15:45:03,131 - modelscope - INFO - Download model 'damo/nlp_gte_sentence-embedding_chinese-small' successfully.\n",
      "2025-01-13 15:45:03,158 - modelscope - INFO - initiate model from /mnt/workspace/.cache/modelscope/hub/damo/nlp_gte_sentence-embedding_chinese-small\n",
      "2025-01-13 15:45:03,159 - modelscope - INFO - initiate model from location /mnt/workspace/.cache/modelscope/hub/damo/nlp_gte_sentence-embedding_chinese-small.\n",
      "2025-01-13 15:45:03,162 - modelscope - INFO - initialize model from /mnt/workspace/.cache/modelscope/hub/damo/nlp_gte_sentence-embedding_chinese-small\n",
      "/root/miniconda3/envs/modelscope_1.21/lib/python3.9/site-packages/transformers/modeling_utils.py:488: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
      "  return torch.load(checkpoint_file, map_location=map_location)\n",
      "2025-01-13 15:45:04,093 - modelscope - WARNING - No preprocessor field found in cfg.\n",
      "2025-01-13 15:45:04,094 - modelscope - WARNING - No val key and type key found in preprocessor domain of configuration.json file.\n",
      "2025-01-13 15:45:04,095 - modelscope - WARNING - Cannot find available config to build preprocessor at mode inference, current config: {'model_dir': '/mnt/workspace/.cache/modelscope/hub/damo/nlp_gte_sentence-embedding_chinese-small'}. trying to build by task and model information.\n",
      "2025-01-13 15:45:04,136 - modelscope - WARNING - No preprocessor field found in cfg.\n",
      "2025-01-13 15:45:04,137 - modelscope - WARNING - No val key and type key found in preprocessor domain of configuration.json file.\n",
      "2025-01-13 15:45:04,137 - modelscope - WARNING - Cannot find available config to build preprocessor at mode inference, current config: {'model_dir': '/mnt/workspace/.cache/modelscope/hub/damo/nlp_gte_sentence-embedding_chinese-small', 'sequence_length': 128}. trying to build by task and model information.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "LLM is explicitly disabled. Using MockLLM.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/miniconda3/envs/modelscope_1.21/lib/python3.9/site-packages/transformers/modeling_utils.py:909: FutureWarning: The `device` argument is deprecated and will be removed in v5 of Transformers.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " 2.2 reading doc done, vec db created.\n",
      "STEP3: chat prompt template created.\n",
      "@@@ query= 西安交大的校训是什么？\n",
      "@@@@ source= Node ID: 1392eeb8-1e5f-48fe-a31a-213e120fb5ba\n",
      "Text: 西安交通大学是我国最早兴办、享誉海内外的著名高等学府，是教育部直属重点大学。西迁以来，一代代交大人扎根西部、服务国家，为西部发展\n",
      "和国家建设作出了卓越贡献，以实际行动铸就了第一批纳入中国共产党人精神谱系的西迁精神。2017年12月，习近平总书记对学校15位老教授来信作出\n",
      "重要指示。在2018年新年贺词中，习近平总书记再次提到“西安交大西迁的老教授”。2020年4月22日，习近平总书记来校考察并发表重要讲话，强\n",
      "调西迁精神的核心是爱国主义，精髓是听党指挥跟党走，与党和国家、与民族和人民同呼吸、共命运，勉励师生在新时代创造属于我们这代人的历史功绩，给全\n",
      "校师生以巨大关怀和极大鼓舞，为学校新时代建设中国特色世界一流大学提供了根本遵循和行动指南。\n",
      "十九世纪末，甲午战败，民族危难。近代著名实业家、教育...\n",
      "Score:  0.916\n",
      "\n",
      "@@@@ source= Node ID: aa992526-59f5-4bb9-95b9-43291b821ce4\n",
      "Text: 2000年国务院决定将西安交通大学、西安医科大学、陕西财经学院三校合并，组建新的西安交通大学。\n",
      "学校是“七五”“八五”重点建设单位，首批进入国家“211”和“985”工程建设学校。2017 年入选国家一流大学建设名单 A\n",
      "类建设高校，2022 年入选国家第二轮“双一流”建设高校，8 个学科入选“双一流”建设学科。据 ESI 公布的数据，截至 2023 年 5\n",
      "月，学校 17 个学科进入世界学术机构前 1%，5 个学科进入前 1‰，其中工程学进入前万分之一。  学校是涵盖理、工、医、经、管、文、法、\n",
      "哲、艺、教育、交叉等11个学科门类的综合性研究型大学，设有32个学院（部、中心）、9个本科书院和3所直属附属医院。现有在编教工6635人，其\n",
      "中专任教师3789人。师资队伍中入选院士、杰青等国...\n",
      "Score:  0.874\n",
      "\n",
      "Human: 请基于```内的内容回答问题。\"\n",
      "```\n",
      "[Document(metadata={'file_path': '/mnt/workspace/custom_data/xianjiaoda.md', 'file_name': 'xianjiaoda.md', 'file_type': 'text/markdown', 'file_size': 13228, 'creation_date': '2025-01-13', 'last_modified_date': '2024-01-16'}, page_content='西安交通大学是我国最早兴办、享誉海内外的著名高等学府，是教育部直属重点大学。西迁以来，一代代交大人扎根西部、服务国家，为西部发展和国家建设作出了卓越贡献，以实际行动铸就了第一批纳入中国共产党人精神谱系的西迁精神。2017年12月，习近平总书记对学校15位老教授来信作出重要指示。在2018年新年贺词中，习近平总书记再次提到“西安交大西迁的老教授”。2020年4月22日，习近平总书记来校考察并发表重要讲话，强调西迁精神的核心是爱国主义，精髓是听党指挥跟党走，与党和国家、与民族和人民同呼吸、共命运，勉励师生在新时代创造属于我们这代人的历史功绩，给全校师生以巨大关怀和极大鼓舞，为学校新时代建设中国特色世界一流大学提供了根本遵循和行动指南。\\n\\n十九世纪末，甲午战败，民族危难。近代著名实业家、教育家盛宣怀秉持“自强首在储才，储才必先兴学”的信念，于1896年在上海创建了南洋公学，1921年定名为交通大学。学校坚持“求实学、务实业”办学宗旨，强调“修一等品行、求一等学问、创一等事业、成一等人才”办学目标。至二十世纪二三十年代，成为独具“理工管”特色的著名大学。抗战时期，学校移至租界，内迁重庆，坚持沪渝两地办学，为抵御外侮，不少学生投笔从戎，浴血沙场。解放前夕，师生积极投入民主革命和解放斗争，学校被誉为“民主堡垒”。\\n1955年中央决定交通大学内迁西安；1956年起师生分批迁赴西安；1957年分设为交通大学西安、上海两个部分，实行统一领导；1959年，交通大学西安部分定名为西安交通大学，同年被列为全国十六所重点大学之一。2000年国务院决定将西安交通大学、西安医科大学、陕西财经学院三校合并，组建新的西安交通大学。\\n\\n学校是“七五”“八五”重点建设单位，首批进入国家“211”和“985”工程建设学校。2017 年入选国家一流大学建设名单 A 类建设高校，2022 年入选国家第二轮“双一流”建设高校，8 个学科入选“双一流”建设学科。'), Document(metadata={'file_path': '/mnt/workspace/custom_data/xianjiaoda.md', 'file_name': 'xianjiaoda.md', 'file_type': 'text/markdown', 'file_size': 13228, 'creation_date': '2025-01-13', 'last_modified_date': '2024-01-16'}, page_content='2000年国务院决定将西安交通大学、西安医科大学、陕西财经学院三校合并，组建新的西安交通大学。\\n\\n学校是“七五”“八五”重点建设单位，首批进入国家“211”和“985”工程建设学校。2017 年入选国家一流大学建设名单 A 类建设高校，2022 年入选国家第二轮“双一流”建设高校，8 个学科入选“双一流”建设学科。据 ESI 公布的数据，截至 2023 年 5 月，学校 17 个学科进入世界学术机构前 1%，5 个学科进入前 1‰，其中工程学进入前万分之一。\\n\\n学校是涵盖理、工、医、经、管、文、法、哲、艺、教育、交叉等11个学科门类的综合性研究型大学，设有32个学院（部、中心）、9个本科书院和3所直属附属医院。现有在编教工6635人，其中专任教师3789人。师资队伍中入选院士、杰青等国家级各类重大人才工程545人次，获评国家级各类创新团队51个，为国家作出突出贡献并享受政府特殊津贴专家450名，国家级教学名师11名。\\n\\n学校现有学生54760名，其中本科生22407名，研究生29285名，留学生3068名；本科招生专业76个、博士学位授权一级学科36个、硕士学位授权一级学科43个、博士专业学位授权点6个、硕士专业学位授权点29个，博士后流动站30个，国家一级重点学科8个、国家二级重点学科8个、国家重点（培育）学科3个，全国（国家）重点实验室8个，国家工程（技术）研究中心10个，国家产教融合创新平台2个，国家国际科技合作基地5个，国家应用数学中心1个，2011协同创新中心1个、其他省部级及以上重点科研基地195个。\\n\\n建校127年来，形成了兴学强国、艰苦创业、崇德尚实、严谨治学的优良传统，起点高、基础厚、要求严、重实践的办学特色，培养出了一大批卓越的政治家、科学家、社会活动家、教育家、企业家、艺术家、医学专家等，如蔡锷、张元济、蔡元培、黄炎培、邵力子、李叔同、邹韬奋、陆定一、钱学森、张光斗、汪道涵、吴文俊、杨嘉墀、徐光宪、姚桐斌、陈能宽、江泽民、侯宗濂、黄旭华、顾诵芬、丁关根、吴自良、蒋新松、蒋正华、王希季、李金华、韩启德等。')]\n",
      "```\n",
      "我的问题是：西安交大的校训是什么？。\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/miniconda3/envs/modelscope_1.21/lib/python3.9/site-packages/transformers/modeling_utils.py:909: FutureWarning: The `device` argument is deprecated and will be removed in v5 of Transformers.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([1., 2.], device='cuda:0')\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "'\"求实学、务实业\"'"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import os\n",
    "from abc import ABC\n",
    "from typing import Any, List, Optional, Dict, cast\n",
    "\n",
    "import torch\n",
    "from langchain_core.language_models.llms import LLM\n",
    "from langchain_core.output_parsers import StrOutputParser\n",
    "from langchain_core.prompts import ChatPromptTemplate\n",
    "from langchain_core.runnables import RunnablePassthrough\n",
    "from modelscope import AutoModelForCausalLM, AutoTokenizer\n",
    "from llama_index.core import VectorStoreIndex, SimpleDirectoryReader, Settings\n",
    "from llama_index.core.embeddings import BaseEmbedding\n",
    "from langchain_core.retrievers import BaseRetriever\n",
    "from langchain_core.callbacks import CallbackManagerForRetrieverRun\n",
    "from langchain_core.documents import Document\n",
    "\n",
    "# configs for LLM\n",
    "llm_name = \"Qwen/Qwen-1_8B-Chat\"\n",
    "llm_revision = \"master\"\n",
    "\n",
    "# configs for embedding model\n",
    "embedding_model = \"damo/nlp_gte_sentence-embedding_chinese-small\"\n",
    "\n",
    "# file path for your custom knowledge base\n",
    "knowledge_doc_file_dir = \"/mnt/workspace/custom_data/\"\n",
    "knowledge_doc_file_path = knowledge_doc_file_dir + \"xianjiaoda.md\"\n",
    "\n",
    "\n",
    "# define our Embedding class to use models in Modelscope\n",
    "class ModelScopeEmbeddings4LlamaIndex(BaseEmbedding, ABC):\n",
    "    embed: Any = None\n",
    "    model_id: str = \"damo/nlp_gte_sentence-embedding_chinese-small\"\n",
    "\n",
    "    def __init__(\n",
    "            self,\n",
    "            model_id: str,\n",
    "            **kwargs: Any,\n",
    "    ) -> None:\n",
    "        super().__init__(**kwargs)\n",
    "        try:\n",
    "            from modelscope.models import Model\n",
    "            from modelscope.pipelines import pipeline\n",
    "            from modelscope.utils.constant import Tasks\n",
    "            self.embed = pipeline(Tasks.sentence_embedding, model=self.model_id)\n",
    "\n",
    "        except ImportError as e:\n",
    "            raise ValueError(\n",
    "                \"Could not import some python packages.\" \"Please install it with `pip install modelscope`.\"\n",
    "            ) from e\n",
    "\n",
    "    def _get_query_embedding(self, query: str) -> List[float]:\n",
    "        text = query.replace(\"\\n\", \" \")\n",
    "        inputs = {\"source_sentence\": [text]}\n",
    "        return self.embed(input=inputs)['text_embedding'][0]\n",
    "\n",
    "    def _get_text_embedding(self, text: str) -> List[float]:\n",
    "        text = text.replace(\"\\n\", \" \")\n",
    "        inputs = {\"source_sentence\": [text]}\n",
    "        return self.embed(input=inputs)['text_embedding'][0]\n",
    "\n",
    "    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:\n",
    "        texts = list(map(lambda x: x.replace(\"\\n\", \" \"), texts))\n",
    "        inputs = {\"source_sentence\": texts}\n",
    "        return self.embed(input=inputs)['text_embedding']\n",
    "\n",
    "    async def _aget_query_embedding(self, query: str) -> List[float]:\n",
    "        return self._get_query_embedding(query)\n",
    "\n",
    "\n",
    "# define our Retriever with llama-index to co-operate with Langchain\n",
    "# note that the 'LlamaIndexRetriever' defined in langchain-community.retrievers.llama_index.py\n",
    "# is no longer compatible with llamaIndex code right now.\n",
    "class LlamaIndexRetriever(BaseRetriever):\n",
    "    index: Any\n",
    "    \"\"\"LlamaIndex index to query.\"\"\"\n",
    "\n",
    "    def _get_relevant_documents(\n",
    "        self, query: str, *, run_manager: CallbackManagerForRetrieverRun\n",
    "    ) -> List[Document]:\n",
    "        \"\"\"Get documents relevant for a query.\"\"\"\n",
    "        try:\n",
    "            from llama_index.core.indices.base import BaseIndex\n",
    "            from llama_index.core import Response\n",
    "        except ImportError:\n",
    "            raise ImportError(\n",
    "                \"You need to install `pip install llama-index` to use this retriever.\"\n",
    "            )\n",
    "        index = cast(BaseIndex, self.index)\n",
    "        print('@@@ query=', query)\n",
    "\n",
    "        response = index.as_query_engine().query(query)\n",
    "        response = cast(Response, response)\n",
    "        # parse source nodes\n",
    "        docs = []\n",
    "        for source_node in response.source_nodes:\n",
    "            print('@@@@ source=', source_node)\n",
    "            metadata = source_node.metadata or {}\n",
    "            docs.append(\n",
    "                Document(page_content=source_node.get_text(), metadata=metadata)\n",
    "            )\n",
    "        return docs\n",
    "\n",
    "def torch_gc():\n",
    "    os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n",
    "    DEVICE = \"cuda\"\n",
    "    DEVICE_ID = \"0\"\n",
    "    CUDA_DEVICE = f\"{DEVICE}:{DEVICE_ID}\" if DEVICE_ID else DEVICE\n",
    "    a = torch.Tensor([1, 2])\n",
    "    a = a.cuda()\n",
    "    print(a)\n",
    "\n",
    "    if torch.cuda.is_available():\n",
    "        with torch.cuda.device(CUDA_DEVICE):\n",
    "            torch.cuda.empty_cache()\n",
    "            torch.cuda.ipc_collect()\n",
    "\n",
    "\n",
    "# global resources used by QianWenChatLLM (this is not a good practice)\n",
    "tokenizer = AutoTokenizer.from_pretrained(llm_name, revision=llm_revision, trust_remote_code=True)\n",
    "model = AutoModelForCausalLM.from_pretrained(llm_name, revision=llm_revision, device_map=\"auto\",\n",
    "                                             trust_remote_code=True, fp16=True).eval()\n",
    "\n",
    "\n",
    "# define QianWen LLM based on langchain's LLM to use models in Modelscope\n",
    "class QianWenChatLLM(LLM):\n",
    "    max_length: int = 10000\n",
    "    temperature: float = 0.01\n",
    "    top_p: float = 0.9\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "\n",
    "    @property\n",
    "    def _llm_type(self):\n",
    "        return \"ChatLLM\"\n",
    "\n",
    "    def _call(\n",
    "            self,\n",
    "            prompt: str,\n",
    "            stop: Optional[List[str]] = None,\n",
    "            run_manager=None,\n",
    "            **kwargs: Any,\n",
    "    ) -> str:\n",
    "        print(prompt)\n",
    "        response, history = model.chat(tokenizer, prompt, history=None)\n",
    "        torch_gc()\n",
    "        return response\n",
    "\n",
    "\n",
    "# STEP1: create LLM instance\n",
    "qwllm = QianWenChatLLM()\n",
    "print('STEP1: qianwen LLM created')\n",
    "\n",
    "# STEP2: load knowledge file and initialize vector db by llamaIndex\n",
    "print('STEP2: reading docs ...')\n",
    "embeddings = ModelScopeEmbeddings4LlamaIndex(model_id=embedding_model)\n",
    "Settings.llm = None\n",
    "Settings.embed_model=embeddings  # global config, not good\n",
    "\n",
    "llamaIndex_docs = SimpleDirectoryReader(knowledge_doc_file_dir).load_data()\n",
    "llamaIndex_index = VectorStoreIndex.from_documents(llamaIndex_docs, chunk_size=512)\n",
    "retriever = LlamaIndexRetriever(index=llamaIndex_index)\n",
    "print(' 2.2 reading doc done, vec db created.')\n",
    "\n",
    "# STEP3: create chat template\n",
    "prompt_template = \"\"\"请基于```内的内容回答问题。\"\n",
    "```\n",
    "{context}\n",
    "```\n",
    "我的问题是：{question}。\n",
    "\"\"\"\n",
    "prompt = ChatPromptTemplate.from_template(template=prompt_template)\n",
    "print('STEP3: chat prompt template created.')\n",
    "\n",
    "# STEP4: create RAG chain to do QA\n",
    "chain = (\n",
    "        {\"context\": retriever, \"question\": RunnablePassthrough()}\n",
    "        | prompt\n",
    "        | qwllm\n",
    "        | StrOutputParser()\n",
    ")\n",
    "chain.invoke('西安交大的校训是什么？')\n",
    "# chain.invoke('魔搭社区有哪些模型?')\n",
    "# chain.invoke('modelscope是什么?')\n",
    "# chain.invoke('萧峰和乔峰是什么关系?')\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.20"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
